import torch
from torchvision import datasets, transforms

# Per-channel (mean, std) passed to ``transforms.Normalize`` for each supported dataset.
_NORM_STATS = {
    "MNIST": ((0.1307,), (0.3081,)),
    "FashionMNIST": ((0.2860,), (0.3530,)),
    "CIFAR10": (
        (0.49139968, 0.48215841, 0.44653091),
        (0.24703223, 0.24348513, 0.26158784),
    ),
    "CIFAR100": (
        (0.50707516, 0.48654887, 0.44091784),
        (0.26733429, 0.25643846, 0.27615047),
    ),
    "SVHN": (
        (0.4376821, 0.4437697, 0.47280442),
        (0.19803012, 0.20101562, 0.19703614),
    ),
}


def get_transforms_with_augment(dataset_name, use_augment=True, magnitude=4, num_ops=2):
    """Return ``(train_transform, test_transform)`` for a supported dataset.

    Args:
        dataset_name: One of ``"MNIST"``, ``"FashionMNIST"``, ``"CIFAR10"``,
            ``"CIFAR100"`` or ``"SVHN"``.
        use_augment: If true, the training pipeline starts with
            ``RandAugment``. If false, MNIST/FashionMNIST still get a random
            crop plus horizontal flip, SVHN gets a random crop, and
            CIFAR10/CIFAR100 get no augmentation.
        magnitude: ``RandAugment`` magnitude.
        num_ops: Number of ``RandAugment`` operations applied per image
            (default 2, the value previously hard-coded).

    Returns:
        A pair of ``transforms.Compose`` objects. Both end with ``Normalize``
        using the dataset's statistics. The test transform is just
        ``ToTensor`` + ``Normalize``, with no augmentation.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    try:
        mean, std = _NORM_STATS[dataset_name]
    except (KeyError, TypeError):
        raise ValueError(f"Dataset {dataset_name} not supported") from None

    if use_augment:
        # RandAugment accepts PIL images directly (the colour datasets already
        # relied on this). The grayscale branches used to go
        # ToTensor -> ToPILImage first. That round trip was redundant and
        # could shift pixel values, because ToPILImage converts float tensors
        # back to uint8 with a truncating ``mul(255).byte()``.
        # NEAREST is RandAugment's default interpolation. It is spelled out so
        # every dataset visibly uses the same setting.
        train_steps = [
            transforms.RandAugment(
                num_ops=num_ops,
                magnitude=magnitude,
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor(),
        ]
    elif dataset_name in ("MNIST", "FashionMNIST"):
        # NOTE(review): horizontal flips do not preserve digit shapes on MNIST;
        # kept as-is to avoid changing existing training behaviour.
        train_steps = [
            transforms.ToTensor(),
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
    elif dataset_name == "SVHN":
        train_steps = [
            transforms.ToTensor(),
            transforms.RandomCrop(32, padding=4),
        ]
    else:
        # CIFAR10 / CIFAR100: no augmentation when RandAugment is disabled.
        train_steps = [transforms.ToTensor()]

    train_transform = transforms.Compose(train_steps + [transforms.Normalize(mean, std)])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, test_transform

def get_dataset(args):
    """Create the training and test ``DataLoader`` pair for ``args.dataset``.

    Reads ``args.dataset`` and ``args.batchsize``. The attributes
    ``args.use_augment`` (default ``True``) and ``args.augment_magnitude``
    (default ``4``) are optional and are forwarded to
    ``get_transforms_with_augment``. Datasets are stored under ``./data`` and
    downloaded if missing.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and the
        test loader does not. CIFAR10, CIFAR100 and SVHN loaders also pin
        host memory.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    augment = getattr(args, 'use_augment', True)
    strength = getattr(args, 'augment_magnitude', 4)
    train_tf, test_tf = get_transforms_with_augment(args.dataset, augment, strength)

    # Each entry maps a dataset name to:
    # (dataset class, train-split kwargs, test-split kwargs, extra DataLoader kwargs).
    # SVHN selects its split by name instead of a ``train`` flag.
    registry = {
        "MNIST": (datasets.MNIST, {'train': True}, {'train': False}, {}),
        "FashionMNIST": (datasets.FashionMNIST, {'train': True}, {'train': False}, {}),
        'CIFAR10': (datasets.CIFAR10, {'train': True}, {'train': False}, {'pin_memory': True}),
        'CIFAR100': (datasets.CIFAR100, {'train': True}, {'train': False}, {'pin_memory': True}),
        'SVHN': (datasets.SVHN, {'split': 'train'}, {'split': 'test'}, {'pin_memory': True}),
    }
    entry = registry.get(args.dataset)
    if entry is None:
        raise ValueError(f"Dataset {args.dataset} not supported")
    dataset_cls, train_split, test_split, loader_opts = entry

    train_set = dataset_cls(root='./data', download=True, transform=train_tf, **train_split)
    test_set = dataset_cls(root='./data', download=True, transform=test_tf, **test_split)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batchsize, shuffle=True, **loader_opts
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batchsize, shuffle=False, **loader_opts
    )
    return train_loader, test_loader

